import copy
import random

import matplotlib.pyplot as plt
import numpy as np

from SVMAgent import MLPAgent, FastComNetwork, ComNetwork, DatasetModel, MLP, MLPOracle


# ---- hyper-parameters -------------------------------------------------------
NUM_AGENTS = 16   # number of agents; the dataset is created with num_agent=NUM_AGENTS
NUM_ROUNDS = 200  # training rounds run for each algorithm
T_RESTART = 20    # period (in rounds) at which every agent's action is re-initialised

# DOC²S: learning rate and radius of the ball the action is projected onto
LR = 0.01
D = 0.01

# MEDOL: learning rate and radius (passed to each MLPAgent as lr / D)
m_LR = 0.01
m_D = 0.001

R = 2       # Chebyshev rounds passed to the DOC²S network propagation calls
p = 0.99    # self-weight of the ring mixing matrix; controls the spectral gap
BATCH_SIZE = 128
DATASET_NAME = 'mnist'
HIDDEN_DIM = 256

# Seed both RNGs (stdlib and NumPy) before anything random is created.
random.seed(42)
np.random.seed(42)

# ---- data and oracle --------------------------------------------------------
# Presumably one data shard per agent; max_sample looks like a cap on the
# number of samples used -- confirm in SVMAgent.DatasetModel.
dataset = DatasetModel(
    dsname=DATASET_NAME,
    num_agent=NUM_AGENTS,
    mb_size=BATCH_SIZE,
    max_sample=10000,
)

# Loss / gradient oracle for the MLP (lam is presumably an L2 regularisation
# weight -- confirm in SVMAgent.MLPOracle).
oracle = MLPOracle(lam=1e-5, hidden_dim=HIDDEN_DIM)

# Ring-topology mixing matrix (used by MEDOL's communication network)
def ring_matrix(n, p):
    """Return the n x n mixing matrix of a ring (cycle) graph.

    Each node keeps weight ``p`` on itself and gives ``(1 - p) / 2`` to each of
    its two ring neighbours. Every row and column therefore sums to 1: the
    matrix is symmetric and doubly stochastic.

    Small rings are handled correctly. For n == 2 both neighbours are the same
    node, which therefore receives ``1 - p``. For n == 1 the single entry is 1.
    For n >= 3 the entries are exactly ``p`` and ``(1 - p) / 2``.

    Args:
        n: number of nodes (agents).
        p: self-weight placed on the diagonal.

    Returns:
        ``(n, n)`` float array.
    """
    eye = np.eye(n)
    # Rolling the identity's columns puts, in row i, a 1 at column (i + 1) % n
    # (shift=1) and at column (i - 1) % n (shift=-1). Summing the two, instead
    # of assigning entries one by one, accumulates the neighbour weight when
    # both neighbours are the same node (n <= 2).
    neighbours = np.roll(eye, 1, axis=1) + np.roll(eye, -1, axis=1)
    return p * eye + (1 - p) / 2 * neighbours


def create_matrix(n):
    """Build the n x n uniform averaging matrix, every entry equal to 1/n.

    Serves as the fully connected mixing matrix of the DOC²S network.
    Raises ZeroDivisionError when n == 0.
    """
    entry = 1 / n
    return np.full(shape=(n, n), fill_value=entry)

# DOC²S train loss
def train_DOC2S(agents, num_rounds, t_restart):
    """Train ``agents`` with DOC²S and record the test loss of their average weight.

    Each round works as follows:

    * One agent is drawn uniformly at random, and every agent calls
      ``get_grad_point()``.
    * The drawn agent refreshes its weight with ``DOC2S_get_new_weight()``.
      It then computes a minibatch gradient on ``dataset.get_sample(selected)``
      and takes a step on its action.
    * That step is projected onto the L2 ball of radius ``agent.D`` and
      multiplied by ``agent.NUM_AGENTS``. Every other agent zeroes its action.
    * Actions and weights are propagated over a network built from the uniform
      averaging matrix, passing ``R`` to the propagation calls.

    Args:
        agents: list of agents (``MLPAgent`` in this script), indexed like the
            dataset shards used by ``dataset.get_sample``.
        num_rounds: number of training rounds.
        t_restart: restart period. Every agent re-initialises its action at
            0-based rounds ``k`` with ``k % t_restart == t_restart - 1``.

    Returns:
        List of test losses: one before training and one after every 10th
        round (``1 + num_rounds // 10`` entries).
    """
    network = FastComNetwork(create_matrix(NUM_AGENTS))
    losses = []

    def _record_loss(step_label):
        # Evaluate the network-average weight on the test set, store and log it.
        avg_w = network.get_average_weight(agents)
        X_test, y_test = dataset.get_test_set()
        loss = oracle.get_fn_val(avg_w, X_test, y_test)
        losses.append(loss)
        # One "round N, loss: X" format for every entry, including round 0.
        print(f"round {step_label}, loss: {loss:.4f}")

    _record_loss(0)

    for k in range(num_rounds):
        if k % t_restart == (t_restart - 1):
            for agent in agents:
                agent.initialize_action()

        selected = np.random.randint(NUM_AGENTS)
        x_mb, y_mb = dataset.get_sample(selected)

        for agent in agents:
            agent.get_grad_point()

        # renew the selected agent's weight
        active = agents[selected]
        active.set_weight(active.DOC2S_get_new_weight())

        # minibatch gradient at the selected agent's gradient point
        grad = oracle.get_gradients(active.get_grad_points(), x_mb, y_mb)

        # get action: only the selected agent contributes a non-zero action
        for i, agent in enumerate(agents):
            if i == selected:
                unprojected = agent.get_action() - agent.lr * grad
                norm = np.linalg.norm(unprojected)
                # Projection onto the ball of radius agent.D. The norm guard
                # avoids dividing by a (near-)zero norm.
                scale = min(1, agent.D / norm) if norm > 1e-8 else 1.0
                # NOTE(review): the NUM_AGENTS factor presumably undoes the
                # 1/n averaging done by propagate_actions, since every other
                # action is zero -- confirm against FastComNetwork.
                agent.set_action(agent.NUM_AGENTS * scale * unprojected)
            else:
                agent.set_action(np.zeros_like(agent.get_action()))

        # communication (R passed to the network as the number of rounds)
        network.propagate_actions(agents, R)
        network.propagate_weights(agents, R)

        # Evaluate every 10 rounds. The logged value 10*(k+1) is the same
        # x-coordinate used by rounds_x_axis when plotting.
        if (k + 1) % 10 == 0:
            _record_loss(10 * (k + 1))

    return losses


# MEDOL train loss
def train_MEDOL(agents, num_rounds, t_restart):
    """Train ``agents`` with MEDOL and track the test loss of their average weight.

    Every round, each agent ``idx`` in turn calls ``get_grad_point()`` and
    replaces its weight with ``get_new_weight()``. It then draws a minibatch
    from ``dataset.get_sample(idx)`` and passes the resulting gradient to
    ``action_grad_update``. Afterwards actions and weights are exchanged over
    a network built from the ring matrix with self-weight ``p``.

    Args:
        agents: list of agents, indexed like the dataset shards.
        num_rounds: number of training rounds.
        t_restart: restart period. All actions are re-initialised at 0-based
            rounds ``k`` with ``k % t_restart == t_restart - 1``.

    Returns:
        List of test losses: one before training, then one every 10 rounds.
    """
    ring = ComNetwork(ring_matrix(NUM_AGENTS, p))
    history = []

    def record(tag):
        # Evaluate the network-average weight on the test set, store and log it.
        mean_weight = ring.get_average_weight(agents)
        features, labels = dataset.get_test_set()
        value = oracle.get_fn_val(mean_weight, features, labels)
        history.append(value)
        print(f"{tag}, loss: {value:.4f}")

    record("round 0")

    for step in range(num_rounds):
        restart_now = step % t_restart == t_restart - 1
        if restart_now:
            for agent in agents:
                agent.initialize_action()

        # local update on every agent, using its own shard
        for idx, agent in enumerate(agents):
            agent.get_grad_point()
            agent.set_weight(agent.get_new_weight())

            batch_x, batch_y = dataset.get_sample(idx)
            gradient = oracle.get_gradients(agent.get_grad_points(), batch_x, batch_y)
            agent.action_grad_update(gradient)

        # exchange over the ring (ComNetwork is called without an R argument)
        ring.propagate_actions(agents)
        ring.propagate_weights(agents)

        completed = step + 1
        if completed % 10 == 0:
            record(f"rounds {10 * completed}")

    return history

# initialize agent
# Two separate agent populations, one per algorithm, each with its own
# learning rate and radius.
doc2s_agents = [
    MLPAgent(
        input_dim=dataset.input_dim,
        hidden_dim=HIDDEN_DIM,
        id=i,
        lr=LR,
        D=D,
        NUM_AGENTS=NUM_AGENTS
    ) for i in range(NUM_AGENTS)
]

medol_agents = [
    MLPAgent(
        input_dim=dataset.input_dim,
        hidden_dim=HIDDEN_DIM,
        id=i,
        lr=m_LR,
        D=m_D,
        NUM_AGENTS=NUM_AGENTS
    ) for i in range(NUM_AGENTS)
]

# Start every agent of both algorithms from the same weights so the two runs
# are comparable. Each agent receives its own deep copy. If MLPAgent.set_weight
# keeps the reference and later updates it in place (not visible here), one
# shared object would let updates leak between agents. It would also let the
# DOC²S run (executed first) alter the MEDOL agents' starting point.
# deepcopy also covers get_weight() returning a container of arrays.
initial_weights = copy.deepcopy(doc2s_agents[0].get_weight())
for doc2s_agent, medol_agent in zip(doc2s_agents, medol_agents):
    doc2s_agent.set_weight(copy.deepcopy(initial_weights))
    medol_agent.set_weight(copy.deepcopy(initial_weights))

# get loss
print("DOC2S...")
doc2s_losses = train_DOC2S(doc2s_agents, NUM_ROUNDS, T_RESTART)

print("\nMEDOL...")
medol_losses = train_MEDOL(medol_agents, NUM_ROUNDS, T_RESTART)

# fig
plt.figure(figsize=(8.5, 7.5))

plt.rc('text', usetex=False)
plt.rc('font', family='sans-serif')
# NOTE: text.latex.preamble only takes effect when text.usetex is True. With
# usetex=False the $...$ labels below are rendered by matplotlib's mathtext.
plt.rcParams['text.latex.preamble'] = r'\usepackage{sfmath}'

# x positions of the recorded losses: 0, then 100, 200, ... This is one entry
# per loss the training functions log every 10 rounds at x = 10*(k+1).
rounds_x_axis = [0] + list(range(100, 10*(NUM_ROUNDS + 1), 100))

plt.plot(rounds_x_axis, doc2s_losses, label=r"$\mathrm{DOC^2S}$",color="black", marker='o',markersize=11)
plt.plot(rounds_x_axis, medol_losses, label=r"$\mathrm{ME{-}DOL}$",color="red",marker='*',markersize=11)
plt.xlabel(r"$\mathrm{Computation~Rounds}$",fontsize=31)
plt.ylabel(r"$\mathrm{Function~value}$",fontsize=31)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.legend(fontsize=27)

plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()